ziya 0.1.49__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ziya might be problematic. Click here for more details.
- app/agents/.agent.py.swp +0 -0
- app/agents/agent.py +315 -113
- app/agents/models.py +439 -0
- app/agents/prompts.py +32 -4
- app/main.py +70 -7
- app/server.py +403 -14
- app/utils/code_util.py +641 -215
- pyproject.toml +3 -3
- templates/asset-manifest.json +18 -20
- templates/index.html +1 -1
- templates/static/css/{main.87f30840.css → main.2bddf34e.css} +2 -2
- templates/static/css/main.2bddf34e.css.map +1 -0
- templates/static/js/46907.90c6a4f3.chunk.js +2 -0
- templates/static/js/46907.90c6a4f3.chunk.js.map +1 -0
- templates/static/js/56122.1d6a5c10.chunk.js +3 -0
- templates/static/js/56122.1d6a5c10.chunk.js.LICENSE.txt +9 -0
- templates/static/js/56122.1d6a5c10.chunk.js.map +1 -0
- templates/static/js/83953.61a908f4.chunk.js +3 -0
- templates/static/js/83953.61a908f4.chunk.js.map +1 -0
- templates/static/js/88261.1e90079d.chunk.js +3 -0
- templates/static/js/88261.1e90079d.chunk.js.map +1 -0
- templates/static/js/{96603.863a8f96.chunk.js → 96603.18c5d644.chunk.js} +2 -2
- templates/static/js/{96603.863a8f96.chunk.js.map → 96603.18c5d644.chunk.js.map} +1 -1
- templates/static/js/{97902.75670155.chunk.js → 97902.d1e262d6.chunk.js} +3 -3
- templates/static/js/{97902.75670155.chunk.js.map → 97902.d1e262d6.chunk.js.map} +1 -1
- templates/static/js/main.9b2b2b57.js +3 -0
- templates/static/js/{main.ee8b3c96.js.LICENSE.txt → main.9b2b2b57.js.LICENSE.txt} +8 -2
- templates/static/js/main.9b2b2b57.js.map +1 -0
- {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/METADATA +5 -5
- {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/RECORD +36 -35
- templates/static/css/main.87f30840.css.map +0 -1
- templates/static/js/23416.c33f07ab.chunk.js +0 -3
- templates/static/js/23416.c33f07ab.chunk.js.map +0 -1
- templates/static/js/3799.fedb612f.chunk.js +0 -2
- templates/static/js/3799.fedb612f.chunk.js.map +0 -1
- templates/static/js/46907.4a730107.chunk.js +0 -2
- templates/static/js/46907.4a730107.chunk.js.map +0 -1
- templates/static/js/64754.cf383335.chunk.js +0 -2
- templates/static/js/64754.cf383335.chunk.js.map +0 -1
- templates/static/js/88261.33450351.chunk.js +0 -3
- templates/static/js/88261.33450351.chunk.js.map +0 -1
- templates/static/js/main.ee8b3c96.js +0 -3
- templates/static/js/main.ee8b3c96.js.map +0 -1
- /templates/static/js/{23416.c33f07ab.chunk.js.LICENSE.txt → 83953.61a908f4.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{88261.33450351.chunk.js.LICENSE.txt → 88261.1e90079d.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{97902.75670155.chunk.js.LICENSE.txt → 97902.d1e262d6.chunk.js.LICENSE.txt} +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/LICENSE +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/WHEEL +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.51.dist-info}/entry_points.txt +0 -0
app/agents/models.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Optional
|
|
3
|
+
import json
|
|
4
|
+
import botocore
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from langchain_aws import ChatBedrock
|
|
7
|
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
8
|
+
from langchain_core.language_models import BaseChatModel
|
|
9
|
+
from langchain.callbacks.base import BaseCallbackHandler
|
|
10
|
+
from app.utils.logging_utils import logger
|
|
11
|
+
import google.auth.exceptions
|
|
12
|
+
import google.auth
|
|
13
|
+
from dotenv import load_dotenv
|
|
14
|
+
from dotenv.main import find_dotenv
|
|
15
|
+
|
|
16
|
+
class ModelManager:
    """Create, cache and describe the chat models Ziya can use.

    All state is kept at class level in ``_state`` and is tagged with the
    process ID that produced it, so a process that inherits the class with a
    different PID re-initializes instead of reusing the parent's model.
    """

    # Class-level state with process-specific initialization
    _state = {
        'model': None,
        'auth_checked': False,
        'auth_success': False,
        'google_credentials': None,
        'aws_profile': None,
        'aws_region': None,
        'process_id': None
    }

    DEFAULT_ENDPOINT = "bedrock"

    DEFAULT_MODELS = {
        "bedrock": "sonnet3.5-v2",
        "google": "gemini-1.5-pro"
    }

    # endpoint -> model alias -> settings used when constructing the model
    MODEL_CONFIGS = {
        "bedrock": {
            "sonnet3.7": {
                "model_id": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
                "token_limit": 200000,
                "max_output_tokens": 128000,
                "temperature": 0.3,
                "top_k": 15,
                "supports_thinking": True,
            },
            "sonnet3.5-v2": {
                "model_id": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
                "token_limit": 200000,
                "max_output_tokens": 4096,
                "temperature": 0.3,
                "top_k": 15,
            },
            "sonnet3.5": {
                "model_id": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
                "token_limit": 200000,
                "max_output_tokens": 4096,
                "temperature": 0.3,
                "top_k": 15,
            },
            "opus": {
                "model_id": "us.anthropic.claude-3-opus-20240229-v1:0",
                "token_limit": 200000,
                "max_output_tokens": 4096,
                "temperature": 0.3,
                "top_k": 15,
            },
            "sonnet": {
                "model_id": "us.anthropic.claude-3-sonnet-20240229-v1:0",
                "token_limit": 200000,
                "max_output_tokens": 4096,
                "temperature": 0.3,
                "top_k": 15,
            },
            "haiku": {
                "model_id": "us.anthropic.claude-3-haiku-20240307-v1:0",
                "token_limit": 200000,
                "max_output_tokens": 4096,
                "temperature": 0.3,
                "top_k": 15,
            },
        },
        "google": {
            "gemini-pro": {
                "model_id": "gemini-pro",
                "token_limit": 30720,
                "max_output_tokens": 2048,
                "temperature": 0.3,
                "convert_system_message_to_human": True,
                "streaming": False,
            },
            "gemini-1.5-pro": {
                "model_id": "gemini-1.5-pro",
                "token_limit": 1000000,
                "max_output_tokens": 2048,
                "temperature": 0.3,
                "convert_system_message_to_human": False,
                "streaming": False,
            }
        }
    }

    @classmethod
    def get_model_config(cls, endpoint: str, model_name: Optional[str] = None) -> dict:
        """Return a copy of a model's configuration with its alias under ``"name"``.

        Args:
            endpoint: Key of ``MODEL_CONFIGS`` ("bedrock" or "google").
            model_name: Either a model alias (e.g. "opus") or a full model ID.
                ``None`` selects the endpoint's entry in ``DEFAULT_MODELS``.

        Raises:
            ValueError: If the endpoint is unknown, or the model is neither an
                alias nor a model ID of that endpoint.
        """
        endpoint_configs = cls.MODEL_CONFIGS.get(endpoint)
        if not endpoint_configs:
            raise ValueError(f"Invalid endpoint: {endpoint}")

        if model_name is None:
            default_name = cls.DEFAULT_MODELS[endpoint]
            return {**endpoint_configs[default_name], "name": default_name}

        # Check if it's a model ID
        for name, config in endpoint_configs.items():
            if config["model_id"] == model_name:
                return {**config, "name": name}

        # Check if it's a model name
        if model_name in endpoint_configs:
            return {**endpoint_configs[model_name], "name": model_name}

        # Neither - show valid options
        valid_models = ", ".join(endpoint_configs)
        raise ValueError(
            f"Invalid model '{model_name}' for {endpoint} endpoint. "
            f"Valid models are: {valid_models}"
        )

    @classmethod
    def _load_credentials(cls) -> bool:
        """Load GOOGLE_API_KEY from .env files or the process environment.

        Candidate files are ``./.env``, ``~/.ziya/.env`` and whatever
        ``find_dotenv()`` locates; each existing file is loaded with
        ``override=True`` until GOOGLE_API_KEY is available. The outcome is
        cached in ``_state`` for the current process.

        Returns:
            True if GOOGLE_API_KEY is found, False otherwise.
        """
        current_pid = os.getpid()

        # Check if we've already loaded credentials in this process
        if cls._state['auth_checked'] and cls._state['process_id'] == current_pid:
            return bool(cls._state['google_credentials'])

        # Reset state for new process
        if cls._state['process_id'] != current_pid:
            cls._state['auth_checked'] = False

        cwd = os.getcwd()
        home = str(Path.home())
        env_locations = {
            'current_dir': os.path.join(cwd, '.env'),
            'home_ziya': os.path.join(home, '.ziya', '.env'),
            'found_dotenv': find_dotenv()
        }

        logger.debug("Searching for .env files:")
        api_key = None
        for location_name, env_file in env_locations.items():
            if not env_file or not os.path.exists(env_file):
                logger.debug(f"No .env file at {location_name}: {env_file}")
                continue

            # The file's contents are deliberately not logged: it holds secrets.
            logger.info(f"Loading credentials from {location_name}: {env_file}")
            if not load_dotenv(env_file, override=True):
                continue

            api_key = os.getenv("GOOGLE_API_KEY")
            if api_key:
                break
            logger.warning(f"Found .env file at {location_name}: {env_file} but it doesn't contain GOOGLE_API_KEY")

        if not api_key:
            logger.debug("No .env file provided GOOGLE_API_KEY, using system environment variables")
            api_key = os.getenv("GOOGLE_API_KEY")

        cls._state.update({
            'auth_checked': True,
            'auth_success': bool(api_key),
            'google_credentials': api_key,
            'process_id': current_pid
        })
        return bool(api_key)

    @classmethod
    def initialize_model(cls, force_reinit: bool = False) -> "BaseChatModel":
        """Initialize and return the appropriate model based on environment settings.

        The endpoint comes from ZIYA_ENDPOINT (default ``DEFAULT_ENDPOINT``) and
        the model from ZIYA_MODEL. The created model is cached per process.

        Args:
            force_reinit: When True, discard any cached model and build a new one.

        Raises:
            ValueError: If the endpoint is unsupported or model setup fails.
        """
        current_pid = os.getpid()
        same_process = cls._state['process_id'] == current_pid

        # Return cached model if it exists for this process
        if not force_reinit and same_process and cls._state['model'] is not None:
            return cls._state['model']

        # Drop state that belongs to another process or that the caller
        # explicitly asked to rebuild.
        if force_reinit or not same_process:
            cls._state['model'] = None
            cls._state['auth_checked'] = False

        endpoint = os.environ.get("ZIYA_ENDPOINT", cls.DEFAULT_ENDPOINT)
        model_name = os.environ.get("ZIYA_MODEL")

        logger.info(f"Initializing model for endpoint: {endpoint}, model: {model_name}")
        if endpoint == "bedrock":
            cls._state['model'] = cls._initialize_bedrock_model(model_name)
        elif endpoint == "google":
            cls._state['model'] = cls._initialize_google_model(model_name)
        else:
            raise ValueError(f"Unsupported endpoint: {endpoint}")

        # Update process ID after successful initialization
        cls._state['process_id'] = current_pid

        return cls._state['model']

    @classmethod
    def _initialize_bedrock_model(cls, model_name: Optional[str] = None) -> "ChatBedrock":
        """Initialize a Bedrock model.

        Temperature, top_k and max output tokens come from the model's config
        and can be overridden with ZIYA_TEMPERATURE, ZIYA_TOP_K and
        ZIYA_MAX_OUTPUT_TOKENS. The AWS profile and region are read from
        ZIYA_AWS_PROFILE and ZIYA_AWS_REGION (default "us-west-2").
        """
        # Imported here so the submodule is loaded even if nothing else has
        # imported it; ``import botocore`` alone does not load ``botocore.config``.
        import botocore.config

        config = cls.get_model_config("bedrock", model_name)
        model_id = config["model_id"]

        if not cls._state['aws_profile']:
            cls._state['aws_profile'] = os.environ.get("ZIYA_AWS_PROFILE")
        if not cls._state.get('aws_region'):
            cls._state['aws_region'] = os.environ.get("ZIYA_AWS_REGION", "us-west-2")

        logger.info(f"Using AWS Profile: {cls._state['aws_profile']}" if cls._state['aws_profile'] else "Using default AWS credentials")

        # Get custom settings if available
        temperature = float(os.environ.get("ZIYA_TEMPERATURE", config.get('temperature', 0.3)))
        top_k = int(os.environ.get("ZIYA_TOP_K", config.get('top_k', 15)))
        max_output = int(os.environ.get("ZIYA_MAX_OUTPUT_TOKENS", config.get('max_output_tokens', 4096)))

        logger.info(f"Initializing Bedrock model: {model_id} with max_tokens: {max_output}, "
                    f"temperature: {temperature}, top_k: {top_k}")

        return ChatBedrock(
            model_id=model_id,
            credentials_profile_name=cls._state['aws_profile'],
            region_name=cls._state['aws_region'],
            config=botocore.config.Config(read_timeout=900, retries={'max_attempts': 3, 'total_max_attempts': 5}),
            model_kwargs={
                "max_tokens": max_output,
                "temperature": temperature,
                "top_k": top_k
            }
        )

    @classmethod
    def _initialize_google_model(cls, model_name: Optional[str] = None) -> "ChatGoogleGenerativeAI":
        """Initialize a Google model.

        Raises:
            ValueError: If no GOOGLE_API_KEY can be loaded, the model name is
                invalid, or constructing the model fails.
        """
        if not model_name:
            model_name = cls.DEFAULT_MODELS["google"]

        # Get the model config (also validates the model name early)
        model_config = cls.get_model_config("google", model_name)

        # Load credentials if not already loaded
        if not cls._state['auth_checked'] and not cls._load_credentials():
            raise ValueError(
                "GOOGLE_API_KEY environment variable is required for google endpoint.\n"
                "You can set it in your environment or create a .env file in either:\n"
                "  - Your current directory\n"
                "  - ~/.ziya/.env\n")

        api_key = os.getenv("GOOGLE_API_KEY")
        if api_key:
            # No part of the key is written to the log.
            logger.debug("Found API key")
            if not api_key.startswith("AI"):
                logger.warning("API key format looks incorrect (should start with 'AI')")
        else:
            logger.debug("No API key found in environment")

        # Check Application Default Credentials (informational only; a failure
        # here is reported in later error messages but is not fatal by itself)
        try:
            credentials, project = google.auth.default()
            logger.debug(f"Found ADC credentials (project: {project})")
        except Exception as e:
            logger.debug(f"No ADC credentials found: {str(e)}")
            credentials = None
            project = None

        logger.info(f"Attempting to initialize Google model: {model_name}")
        logger.info(f"Using model config: {json.dumps(model_config, indent=2)}")

        # Extract parameters from config and override with environment settings if available
        convert_system = model_config.get("convert_system_message_to_human", True)
        temperature = float(os.environ.get("ZIYA_TEMPERATURE",
                                           model_config.get("temperature", 0.3)))
        max_output_tokens = model_config.get("max_output_tokens", 2048)

        # top_k is only forwarded when explicitly configured with a positive
        # value; a 0 default is not a usable sampling setting for this client.
        top_k_setting = os.environ.get("ZIYA_TOP_K", model_config.get("top_k"))
        top_k = int(top_k_setting) if top_k_setting is not None else None
        if top_k is not None and top_k <= 0:
            top_k = None

        try:
            # Use our custom wrapper class instead of ChatGoogleGenerativeAI directly
            model = SafeChatGoogleGenerativeAI(
                model=model_config["model_id"],
                convert_system_message_to_human=convert_system,
                temperature=temperature,
                max_output_tokens=max_output_tokens,
                top_k=top_k,
                client_options={"api_endpoint": "generativelanguage.googleapis.com"},
                max_retries=3,
                verbose=os.environ.get("ZIYA_THINKING_MODE") == "1"
            )
            model.callbacks = [EmptyMessageFilter()]
            logger.info("Successfully connected to Google API")
            return model

        except google.auth.exceptions.DefaultCredentialsError as e:
            logger.error(f"Authentication error details: {str(e)}")
            raise ValueError(
                "\nGoogle API authentication failed. You need to either:\n\n"
                "1. Use an API key (recommended for testing):\n"
                "   - Get an API key from: https://makersuite.google.com/app/apikey\n"
                "   - Add to .env file: GOOGLE_API_KEY=your_key_here\n"
                f"   Current API key status: {'Found' if api_key else 'Not found'}\n\n"
                "2. Or set up Application Default Credentials (for production):\n"
                "   - Install gcloud CLI: https://cloud.google.com/sdk/docs/install\n"
                "   - Run: gcloud auth application-default login\n"
                "   - See: https://cloud.google.com/docs/authentication/external/set-up-adc\n"
                f"   Current ADC status: {'Found' if credentials else 'Not found'}\n\n"
                "Choose option 1 (API key) if you're just getting started.\n"
            ) from e
        except Exception as e:
            logger.error(f"Unexpected error initializing Google model: {str(e)}")
            raise ValueError(
                f"\nFailed to initialize Google model: {str(e)}\n\n"
                f"API key status: {'Found' if api_key else 'Not found'}\n"
                f"ADC status: {'Found' if credentials else 'Not found'}\n"
                "Please check your credentials and try again."
            ) from e

    @classmethod
    def get_available_models(cls, endpoint: Optional[str] = None) -> list[str]:
        """Get list of available model aliases for the specified endpoint.

        Args:
            endpoint: Endpoint name; ``None`` uses ZIYA_ENDPOINT, falling back
                to ``DEFAULT_ENDPOINT``.

        Raises:
            ValueError: If the endpoint has no entry in ``MODEL_CONFIGS``.
        """
        if endpoint is None:
            endpoint = os.environ.get("ZIYA_ENDPOINT", cls.DEFAULT_ENDPOINT)

        try:
            return list(cls.MODEL_CONFIGS[endpoint])
        except KeyError:
            raise ValueError(f"Unsupported endpoint: {endpoint}") from None
|
|
379
|
+
|
|
380
|
+
class EmptyMessageFilter(BaseCallbackHandler):
    """Filter out empty messages before they reach the Gemini API."""

    def on_chat_model_start(self, serialized, messages, **kwargs):
        """Check messages before they're sent to the model.

        ``messages`` may be a flat list of messages or a list of message
        batches (lists); batches are flattened one level before checking.
        Message objects and dicts whose content is empty or whitespace-only
        have that content replaced in place with a placeholder question.

        Returns:
            The same ``messages`` object that was passed in.
        """
        # Flatten one level so both shapes are handled by the same loop.
        # Only real lists are treated as batches: a tuple can itself be a
        # ("role", "content") message.
        flat = []
        for entry in messages:
            if isinstance(entry, list):
                flat.extend(entry)
            else:
                flat.append(entry)

        for i, message in enumerate(flat):
            if isinstance(message, dict):
                # Handle messages with dict content
                if 'content' not in message:
                    continue
                content = message['content']
                # Content is not always a string (e.g. a list of parts), so
                # only strings are stripped before the emptiness check.
                if not content or (isinstance(content, str) and not content.strip()):
                    logger.warning(f"Empty dict message detected in position {i}, replacing with placeholder")
                    message['content'] = "Please provide a question."
            elif hasattr(message, 'content'):
                content = message.content
                # If content is empty, replace with a placeholder
                if not content or (isinstance(content, str) and not content.strip()):
                    logger.warning(f"Empty message detected in position {i}, replacing with placeholder")
                    message.content = "Please provide a question."

        # Check if all messages are empty
        if not any(
            (m.get('content') if isinstance(m, dict) else getattr(m, 'content', None))
            for m in flat
        ):
            logger.error("All messages are empty, adding a placeholder message")
            messages.append({"role": "user", "content": "Please provide a question."})
        return messages
|
|
405
|
+
|
|
406
|
+
# Create a custom wrapper class for ChatGoogleGenerativeAI
|
|
407
|
+
class SafeChatGoogleGenerativeAI(ChatGoogleGenerativeAI):
    """A wrapper around ChatGoogleGenerativeAI that prevents empty messages."""

    def _validate_messages(self, messages):
        """Ensure no messages have empty content.

        Message objects and dicts whose content is empty or whitespace-only
        get a placeholder question written into them in place. Nested lists
        (batches of messages) are validated recursively.

        Returns:
            The same ``messages`` object that was passed in.
        """
        logger.info(f"Validating {len(messages)} messages")
        for i, msg in enumerate(messages):
            if isinstance(msg, list):
                # A batch of messages: descend into it.
                self._validate_messages(msg)
            elif hasattr(msg, 'content'):
                content = msg.content
                # Content is not always a string (e.g. a list of parts), so
                # only strings are stripped before the emptiness check.
                if not content or (isinstance(content, str) and not content.strip()):
                    logger.warning(f"Empty message detected at position {i}, replacing with placeholder")
                    msg.content = "Please provide a question."
            elif isinstance(msg, dict) and 'content' in msg:
                content = msg['content']
                if not content or (isinstance(content, str) and not content.strip()):
                    logger.warning(f"Empty dict message detected at position {i}, replacing with placeholder")
                    msg['content'] = "Please provide a question."
        return messages

    async def agenerate(self, messages, *args, **kwargs):
        """Override agenerate to validate messages."""
        messages = self._validate_messages(messages)
        return await super().agenerate(messages, *args, **kwargs)

    def generate(self, messages, *args, **kwargs):
        """Override generate to validate messages."""
        messages = self._validate_messages(messages)
        return super().generate(messages, *args, **kwargs)

    async def ainvoke(self, input, *args, **kwargs):
        """Override ainvoke to validate input.

        Only list inputs are validated; any other input is passed through
        to the parent implementation unchanged.
        """
        if isinstance(input, list):
            input = self._validate_messages(input)
        return await super().ainvoke(input, *args, **kwargs)
|
|
439
|
+
|
app/agents/prompts.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
2
2
|
# import pydevd_pycharm
|
|
3
|
+
from app.utils.logging_utils import logger
|
|
3
4
|
|
|
4
5
|
template = """
|
|
5
6
|
|
|
@@ -14,6 +15,9 @@ CRITICAL: INSTRUCTION PRESERVATION:
|
|
|
14
15
|
- If instructions seem to conflict, ask for clarification
|
|
15
16
|
|
|
16
17
|
You are an excellent coder. Help the user with their coding tasks. You are given the codebase of the user in your context.
|
|
18
|
+
surgical debug and fixes.
|
|
19
|
+
ask for clarification rather than declaring your surety unless you are absolutely certain.
|
|
20
|
+
I don't need you to be confident, I need you to be correct.
|
|
17
21
|
|
|
18
22
|
IMPORTANT: Code Context Format
|
|
19
23
|
|
|
@@ -100,7 +104,7 @@ CRITICAL: When generating hunks and context:
|
|
|
100
104
|
4. Double-check that context lines exist in the original file
|
|
101
105
|
|
|
102
106
|
CRITICAL: VISUALIZATION CAPABILITIES:
|
|
103
|
-
You can generate inline diagrams using either ```graphviz
|
|
107
|
+
You can generate inline diagrams using either ```graphviz code blocks.
|
|
104
108
|
Actively look for opportunities to enhance explanations with visual representations
|
|
105
109
|
when they would provide clearer understanding, especially for:
|
|
106
110
|
- System architectures
|
|
@@ -211,12 +215,36 @@ Remember to strictly adhere to the Git diff format guidelines provided above whe
|
|
|
211
215
|
|
|
212
216
|
"""
|
|
213
217
|
|
|
218
|
+
# Create a wrapper around the original template
|
|
219
|
+
original_template = template
|
|
220
|
+
|
|
221
|
+
def log_template_variables(variables):
    """Log the ``question`` entry of ``variables`` and return the template.

    The string 'EMPTY' is logged when ``variables`` has no ``question`` key.
    The return value is always the module-level ``original_template``.
    """
    question = variables.get('question', 'EMPTY')
    logger.info(f"Template variables: {question}")
    return original_template
|
|
224
|
+
|
|
225
|
+
# Debug function to log template variables
|
|
226
|
+
def debug_question_template(question):
    """Log the type, value and emptiness of ``question``, then return it unchanged.

    A question counts as empty when it is falsy or a whitespace-only string.
    Non-string values are accepted so that this debugging aid never raises
    on unexpected input.
    """
    is_empty = not question or (isinstance(question, str) and not question.strip())

    logger.info("====== TEMPLATE QUESTION DEBUG ======")
    logger.info(f"Question type: {type(question)}")
    logger.info(f"Question value: '{question}'")
    logger.info(f"Question is empty: {is_empty}")
    logger.info("====== END TEMPLATE QUESTION DEBUG ======")
    return question
|
|
233
|
+
|
|
234
|
+
# Debug function to log chat history
|
|
235
|
+
def debug_chat_history(chat_history):
    """Log the type and length of ``chat_history``, then return it unchanged.

    'N/A' is logged as the length when the object has no ``__len__``.
    """
    if hasattr(chat_history, '__len__'):
        history_length = len(chat_history)
    else:
        history_length = 'N/A'

    logger.info("====== TEMPLATE CHAT HISTORY DEBUG ======")
    logger.info(f"Chat history type: {type(chat_history)}")
    logger.info(f"Chat history length: {history_length}")
    logger.info("====== END TEMPLATE CHAT HISTORY DEBUG ======")
    return chat_history
|
|
241
|
+
|
|
214
242
|
conversational_prompt = ChatPromptTemplate.from_messages(
|
|
215
243
|
[
|
|
216
244
|
("system", template),
|
|
217
|
-
|
|
218
|
-
|
|
245
|
+
MessagesPlaceholder(variable_name="chat_history", optional=True),
|
|
246
|
+
|
|
219
247
|
("user", "{question}"),
|
|
220
|
-
("
|
|
248
|
+
MessagesPlaceholder(variable_name="agent_scratchpad", optional=True),
|
|
221
249
|
]
|
|
222
250
|
)
|
app/main.py
CHANGED
|
@@ -9,6 +9,8 @@ from langchain_cli.cli import serve
|
|
|
9
9
|
from app.utils.logging_utils import logger
|
|
10
10
|
from app.utils.langchain_validation_util import validate_langchain_vars
|
|
11
11
|
from app.utils.version_util import get_current_version, get_latest_version
|
|
12
|
+
from app.server import DEFAULT_PORT
|
|
13
|
+
from app.agents.models import ModelManager
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
def parse_arguments():
|
|
@@ -17,14 +19,23 @@ def parse_arguments():
|
|
|
17
19
|
help="List of files or directories to exclude (e.g., --exclude 'tst,build,*.py')")
|
|
18
20
|
parser.add_argument("--profile", type=str, default=None,
|
|
19
21
|
help="AWS profile to use (e.g., --profile ziya)")
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
22
|
+
|
|
23
|
+
# Get default model alias from ModelManager based on default endpoint
|
|
24
|
+
default_model = ModelManager.DEFAULT_MODELS[ModelManager.DEFAULT_ENDPOINT]
|
|
25
|
+
parser.add_argument("--endpoint", type=str, choices=["bedrock", "google"], default=ModelManager.DEFAULT_ENDPOINT,
|
|
26
|
+
help=f"Model endpoint to use (default: {ModelManager.DEFAULT_ENDPOINT})")
|
|
27
|
+
parser.add_argument("--model", type=str, default=None,
|
|
28
|
+
help=f"Model to use from selected endpoint (default: {default_model})")
|
|
29
|
+
parser.add_argument("--port", type=int, default=DEFAULT_PORT,
|
|
30
|
+
help=(f"Port number to run Ziya frontend on "
|
|
31
|
+
f"(default: {DEFAULT_PORT}, e.g., --port 8080)"))
|
|
32
|
+
|
|
24
33
|
parser.add_argument("--version", action="store_true",
|
|
25
34
|
help="Prints the version of Ziya")
|
|
26
35
|
parser.add_argument("--max-depth", type=int, default=15,
|
|
27
36
|
help="Maximum depth for folder structure traversal (e.g., --max-depth 20)")
|
|
37
|
+
parser.add_argument("--check-auth", action="store_true",
|
|
38
|
+
help="Check authentication setup without starting the server")
|
|
28
39
|
return parser.parse_args()
|
|
29
40
|
|
|
30
41
|
|
|
@@ -36,8 +47,20 @@ def setup_environment(args):
|
|
|
36
47
|
|
|
37
48
|
if args.profile:
|
|
38
49
|
os.environ["ZIYA_AWS_PROFILE"] = args.profile
|
|
50
|
+
|
|
51
|
+
os.environ["ZIYA_ENDPOINT"] = args.endpoint
|
|
39
52
|
if args.model:
|
|
40
|
-
os.environ["
|
|
53
|
+
os.environ["ZIYA_MODEL"] = args.model
|
|
54
|
+
|
|
55
|
+
# If using Google endpoint, ensure credentials are available
|
|
56
|
+
if args.endpoint == "google" and not ModelManager._load_credentials():
|
|
57
|
+
logger.error(
|
|
58
|
+
"\nGOOGLE_API_KEY environment variable is required for google endpoint.\n"
|
|
59
|
+
"You can set it in your environment or create a .env file in either:\n"
|
|
60
|
+
" - Your current directory\n"
|
|
61
|
+
" - ~/.ziya/.env\n")
|
|
62
|
+
sys.exit(1)
|
|
63
|
+
|
|
41
64
|
os.environ["ZIYA_MAX_DEPTH"] = str(args.max_depth)
|
|
42
65
|
|
|
43
66
|
|
|
@@ -96,7 +119,38 @@ def start_server(args):
|
|
|
96
119
|
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
|
97
120
|
# Override the default server location from 127.0.0.1 to 0.0.0.0
|
|
98
121
|
# This allows the server to be accessible from other machines on the network
|
|
99
|
-
|
|
122
|
+
try:
|
|
123
|
+
# Pre-initialize the model to catch any credential issues before starting the server
|
|
124
|
+
logger.info("Performing initial authentication check...")
|
|
125
|
+
try:
|
|
126
|
+
# Try to initialize the model before starting the server
|
|
127
|
+
ModelManager.initialize_model()
|
|
128
|
+
logger.info("Authentication successful, starting server...")
|
|
129
|
+
serve(host="0.0.0.0", port=args.port)
|
|
130
|
+
except ValueError as e:
|
|
131
|
+
logger.error(f"\n{str(e)}")
|
|
132
|
+
logger.error("Server startup aborted due to configuration error.")
|
|
133
|
+
sys.exit(1)
|
|
134
|
+
except ValueError as e:
|
|
135
|
+
logger.error(f"\n{str(e)}")
|
|
136
|
+
logger.error("Server startup aborted due to configuration error.")
|
|
137
|
+
sys.exit(1)
|
|
138
|
+
|
|
139
|
+
def check_auth(args):
    """Check authentication setup without starting the server.

    Applies the command-line arguments to the environment, then initializes
    the model unless an authentication check has already been recorded for
    this process.

    Args:
        args: Parsed command-line arguments, passed to ``setup_environment``.

    Returns:
        True if the check succeeded, False if it failed or raised.
    """
    try:
        setup_environment(args)
        already_checked = (
            ModelManager._state['auth_checked']
            and ModelManager._state['process_id'] == os.getpid()
        )
        # Only initialize if not already done
        if not already_checked:
            ModelManager.initialize_model()
        elif not ModelManager._state['auth_success']:
            logger.error("Previous authentication attempt failed")
            return False
    except Exception as e:
        # Top-level boundary for the CLI: report the failure as a False result.
        logger.error(f"Authentication check failed: {str(e)}")
        return False
    else:
        logger.info("Authentication check successful!")
        return True
|
|
100
154
|
|
|
101
155
|
|
|
102
156
|
def main():
|
|
@@ -106,7 +160,16 @@ def main():
|
|
|
106
160
|
print_version()
|
|
107
161
|
return
|
|
108
162
|
|
|
109
|
-
|
|
163
|
+
if args.check_auth:
|
|
164
|
+
success = check_auth(args)
|
|
165
|
+
sys.exit(0 if success else 1)
|
|
166
|
+
return
|
|
167
|
+
|
|
168
|
+
try:
|
|
169
|
+
check_version_and_upgrade()
|
|
170
|
+
except Exception as e:
|
|
171
|
+
logger.error(f"Error checking version: {e}")
|
|
172
|
+
logger.warning("Continuing with current version...")
|
|
110
173
|
validate_langchain_vars()
|
|
111
174
|
setup_environment(args)
|
|
112
175
|
start_server(args)
|