ziya 0.1.49__py3-none-any.whl → 0.1.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ziya might be problematic.
- app/agents/.agent.py.swp +0 -0
- app/agents/agent.py +315 -113
- app/agents/models.py +439 -0
- app/agents/prompts.py +32 -4
- app/main.py +70 -7
- app/server.py +403 -14
- app/utils/code_util.py +641 -215
- pyproject.toml +2 -3
- templates/asset-manifest.json +18 -20
- templates/index.html +1 -1
- templates/static/css/{main.87f30840.css → main.2bddf34e.css} +2 -2
- templates/static/css/main.2bddf34e.css.map +1 -0
- templates/static/js/46907.90c6a4f3.chunk.js +2 -0
- templates/static/js/46907.90c6a4f3.chunk.js.map +1 -0
- templates/static/js/56122.1d6a5c10.chunk.js +3 -0
- templates/static/js/56122.1d6a5c10.chunk.js.LICENSE.txt +9 -0
- templates/static/js/56122.1d6a5c10.chunk.js.map +1 -0
- templates/static/js/83953.61a908f4.chunk.js +3 -0
- templates/static/js/83953.61a908f4.chunk.js.map +1 -0
- templates/static/js/88261.1e90079d.chunk.js +3 -0
- templates/static/js/88261.1e90079d.chunk.js.map +1 -0
- templates/static/js/{96603.863a8f96.chunk.js → 96603.18c5d644.chunk.js} +2 -2
- templates/static/js/{96603.863a8f96.chunk.js.map → 96603.18c5d644.chunk.js.map} +1 -1
- templates/static/js/{97902.75670155.chunk.js → 97902.d1e262d6.chunk.js} +3 -3
- templates/static/js/{97902.75670155.chunk.js.map → 97902.d1e262d6.chunk.js.map} +1 -1
- templates/static/js/main.9b2b2b57.js +3 -0
- templates/static/js/{main.ee8b3c96.js.LICENSE.txt → main.9b2b2b57.js.LICENSE.txt} +8 -2
- templates/static/js/main.9b2b2b57.js.map +1 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/METADATA +4 -5
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/RECORD +36 -35
- templates/static/css/main.87f30840.css.map +0 -1
- templates/static/js/23416.c33f07ab.chunk.js +0 -3
- templates/static/js/23416.c33f07ab.chunk.js.map +0 -1
- templates/static/js/3799.fedb612f.chunk.js +0 -2
- templates/static/js/3799.fedb612f.chunk.js.map +0 -1
- templates/static/js/46907.4a730107.chunk.js +0 -2
- templates/static/js/46907.4a730107.chunk.js.map +0 -1
- templates/static/js/64754.cf383335.chunk.js +0 -2
- templates/static/js/64754.cf383335.chunk.js.map +0 -1
- templates/static/js/88261.33450351.chunk.js +0 -3
- templates/static/js/88261.33450351.chunk.js.map +0 -1
- templates/static/js/main.ee8b3c96.js +0 -3
- templates/static/js/main.ee8b3c96.js.map +0 -1
- /templates/static/js/{23416.c33f07ab.chunk.js.LICENSE.txt → 83953.61a908f4.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{88261.33450351.chunk.js.LICENSE.txt → 88261.1e90079d.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{97902.75670155.chunk.js.LICENSE.txt → 97902.d1e262d6.chunk.js.LICENSE.txt} +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/LICENSE +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/WHEEL +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/entry_points.txt +0 -0
app/server.py
CHANGED
@@ -1,7 +1,7 @@
 import os
 import time
 import json
-from typing import Dict, Any, List, Tuple, Optional
+from typing import Dict, Any, List, Tuple, Optional, Union

 import tiktoken
 from fastapi import FastAPI, Request, HTTPException
@@ -10,15 +10,20 @@ from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
 from fastapi.templating import Jinja2Templates
 from langserve import add_routes
-from app.agents.agent import model
+from app.agents.agent import model, RetryingChatBedrock
 from app.agents.agent import agent_executor
-from
-from
+from app.agents.agent import update_conversation_state, update_and_return
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import BaseModel, Field
+from app.agents.models import ModelManager
 from botocore.exceptions import ClientError, BotoCoreError, CredentialRetrievalError
 from botocore.exceptions import EventStreamError
+import botocore.errorfactory
 from starlette.responses import StreamingResponse

 # import pydevd_pycharm
+from google.api_core.exceptions import ResourceExhausted
 import uvicorn

 from app.utils.code_util import use_git_to_apply_code_diff, correct_git_diff, PatchApplicationError
@@ -26,6 +31,13 @@ from app.utils.directory_util import get_ignored_patterns
 from app.utils.logging_utils import logger
 from app.utils.gitignore_parser import parse_gitignore_patterns

+# Server configuration defaults
+DEFAULT_PORT = 6969
+# For model configurations, see app/agents/model.py
+
+class SetModelRequest(BaseModel):
+    model_id: str
+
 app = FastAPI()

 app.add_middleware(
@@ -74,7 +86,6 @@ async def credential_exception_handler(request: Request, exc: CredentialRetrieva
         headers={"WWW-Authenticate": "Bearer"}
     )

-
 @app.exception_handler(ClientError)
 async def boto_client_exception_handler(request: Request, exc: ClientError):
     error_message = str(exc)
@@ -84,6 +95,13 @@ async def boto_client_exception_handler(request: Request, exc: ClientError):
             content={"detail": "AWS credentials have expired. Please refresh your credentials."},
             headers={"WWW-Authenticate": "Bearer"}
         )
+    elif "ValidationException" in error_message:
+        logger.error(f"Bedrock validation error: {error_message}")
+        return JSONResponse(
+            status_code=400,
+            content={"error": "validation_error",
+                     "detail": "Invalid request format for Bedrock service. Please check your input format.",
+                     "message": error_message})
     elif "ServiceUnavailableException" in error_message:
         return JSONResponse(
             status_code=503,
@@ -94,11 +112,77 @@ async def boto_client_exception_handler(request: Request, exc: ClientError):
         content={"detail": f"AWS Service Error: {str(exc)}"}
     )

+@app.exception_handler(ResourceExhausted)
+async def resource_exhausted_handler(request: Request, exc: ResourceExhausted):
+    """Handle Google API quota exceeded errors."""
+    logger.error(f"Google API quota exceeded: {str(exc)}")
+    return JSONResponse(
+        status_code=429,  # Too Many Requests
+        content={
+            "error": "quota_exceeded",
+            "detail": "API quota has been exceeded. Please try again in a few minutes.",
+            "original_error": str(exc)
+        }
+    )
+
+@app.exception_handler(ResourceExhausted)
+async def resource_exhausted_handler(request: Request, exc: ResourceExhausted):
+    """Handle Google API quota exceeded errors."""
+    logger.error(f"Google API quota exceeded: {str(exc)}")
+    return JSONResponse(
+        status_code=429,  # Too Many Requests
+        content={
+            "error": "quota_exceeded",
+            "detail": "API quota has been exceeded. Please try again in a few minutes.",
+            "original_error": str(exc)
+        }
+    )
+
 @app.exception_handler(Exception)
 async def general_exception_handler(request: Request, exc: Exception):
     error_message = str(exc)
     status_code = 500
     error_type = "unknown_error"
+
+    # Check for empty text parameter error from Gemini
+    if "Unable to submit request because it has an empty text parameter" in error_message:
+        logger.error("Caught empty text parameter error from Gemini")
+        return JSONResponse(
+            status_code=400,
+            content={
+                "error": "validation_error",
+                "detail": "Empty message content detected. Please provide a question."
+            }
+        )
+
+    # Check for Google API quota exceeded error
+    if "Resource has been exhausted" in error_message and "check quota" in error_message:
+        return JSONResponse(
+            status_code=429,  # Too Many Requests
+            content={
+                "error": "quota_exceeded",
+                "detail": "API quota has been exceeded. Please try again in a few minutes."
+            })
+
+    # Check for Gemini token limit error
+    if isinstance(exc, ChatGoogleGenerativeAIError) and "token count" in error_message:
+        return JSONResponse(
+            status_code=413,
+            content={
+                "error": "validation_error",
+                "detail": "Selected content is too large for the model. Please reduce the number of files."
+            }
+        )
+
+    # Check for Google API quota exceeded error
+    if "Resource has been exhausted" in error_message and "check quota" in error_message:
+        return JSONResponse(
+            status_code=429,  # Too Many Requests
+            content={
+                "error": "quota_exceeded",
+                "detail": "API quota has been exceeded. Please try again in a few minutes."
+            })
+
     try:
         # Check if this is a streaming error
         if isinstance(exc, EventStreamError):
@@ -124,7 +208,21 @@ async def general_exception_handler(request: Request, exc: Exception):
         logger.error(f"Error in exception handler: {str(e)}", exc_info=True)
         raise

-
+# Get the absolute path to the project root directory
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+# Define paths relative to project root
+static_dir = os.path.join(project_root, "templates", "static")
+testcases_dir = os.path.join(project_root, "tests", "frontend", "testcases")
+templates_dir = os.path.join(project_root, "templates")
+
+# Create directories if they don't exist
+os.makedirs(static_dir, exist_ok=True)
+os.makedirs(testcases_dir, exist_ok=True)
+os.makedirs(templates_dir, exist_ok=True)
+
+# Mount static files and templates
+app.mount("/static", StaticFiles(directory=static_dir), name="static")

 # Only mount testcases directory if it exists
 testcases_dir = "../tests/frontend/testcases"
@@ -133,27 +231,109 @@ if os.path.exists(testcases_dir):
 else:
     logger.info(f"Testcases directory '{testcases_dir}' does not exist - skipping mount")

-templates = Jinja2Templates(directory=
+templates = Jinja2Templates(directory=templates_dir)

 # Add a route for the frontend
 add_routes(app, agent_executor, disabled_endpoints=["playground"], path="/ziya")
 # Override the stream endpoint with our error handling
 @app.post("/ziya/stream")
 async def stream_endpoint(body: dict):
+
+    # Debug logging
+    logger.info("Stream endpoint request body:")
+    logger.info(f"Question: '{body.get('question', 'EMPTY')}'")
+    logger.info(f"Chat history length: {len(body.get('chat_history', []))}")
+    logger.info(f"Files count: {len(body.get('config', {}).get('files', []))}")
+    logger.info(f"Question type: {type(body.get('question', None))}")
+
+    # Log the first few files
+    if 'config' in body and 'files' in body['config']:
+        logger.info(f"First few files: {body['config']['files'][:5]}")
+
+    # Check if the question is empty or missing
+    if not body.get("question") or not body.get("question").strip():
+        logger.warning("Empty question detected, returning error response")
+        error_response = json.dumps({
+            "error": "validation_error",
+            "detail": "Please provide a question to continue."
+        })
+
+        # Return a properly formatted SSE response with the error
+        async def error_stream():
+            # Send the error message
+            yield f"data: {error_response}\n\n"
+            # Wait a moment to ensure the client receives it
+            await asyncio.sleep(0.1)
+            # Send an end message
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            error_stream(),
+            media_type="text/event-stream",
+            headers={"Cache-Control": "no-cache"}
+        )
     try:
+        # Check for empty question
+        if not body.get("question") or not body.get("question").strip():
+            logger.warning("Empty question detected in stream request")
+            # Return a friendly error message
+            return StreamingResponse(
+                iter([f'data: {json.dumps({"error": "validation_error", "detail": "Please enter a question"})}' + '\n\n']),
+                media_type="text/event-stream",
+                headers={"Cache-Control": "no-cache"}
+            )
+
+        # Check for empty messages in chat history
+        if "chat_history" in body:
+            cleaned_history = []
+            for pair in body["chat_history"]:
+                try:
+                    if not isinstance(pair, (list, tuple)) or len(pair) != 2:
+                        logger.warning(f"Invalid chat history pair format: {pair}")
+                        continue
+
+                    human, ai = pair
+                    if not isinstance(human, str) or not isinstance(ai, str):
+                        logger.warning(f"Non-string message in pair: human={type(human)}, ai={type(ai)}")
+                        continue
+
+                    if human.strip() and ai.strip():
+                        cleaned_history.append((human.strip(), ai.strip()))
+                    else:
+                        logger.warning(f"Empty message in pair: {pair}")
+                except Exception as e:
+                    logger.error(f"Error processing chat history pair: {str(e)}")
+
+            logger.debug(f"Cleaned chat history from {len(body['chat_history'])} to {len(cleaned_history)} pairs")
+            body["chat_history"] = cleaned_history
+            logger.debug(f"Cleaned chat history: {json.dumps(cleaned_history)}")
+
         logger.info("Starting stream endpoint with body size: %d", len(str(body)))
         # Define the streaming response with proper error handling
         async def error_handled_stream():
+            response = None
             try:
+                # Convert to ChatPromptValue before streaming
+                if isinstance(body, dict) and "messages" in body:
+                    from langchain_core.prompt_values import ChatPromptValue
+                    from langchain_core.messages import HumanMessage
+                    body["messages"] = [HumanMessage(content=msg) for msg in body["messages"]]
+                    body = ChatPromptValue(messages=body["messages"])
                 # Create the iterator inside the error handling context
-                iterator = agent_executor.astream_log(body
+                iterator = agent_executor.astream_log(body)
                 async for chunk in iterator:
                     logger.info("Processing chunk: %s",
                                 chunk if isinstance(chunk, dict) else chunk[:200] + "..." if len(chunk) > 200 else chunk)
                     if isinstance(chunk, dict) and "error" in chunk:
                         # Format error as SSE message
                         yield f"data: {json.dumps(chunk)}\n\n"
-
+                        # Update file state before returning
+                        update_and_return(body)
+                        logger.info(f"Sent error message: {chunk}")
+                        return
+                    elif isinstance(chunk, Generation) and hasattr(chunk, 'text') and "quota_exceeded" in chunk.text:
+                        yield f"data: {chunk.text}\n\n"
+                        update_and_return(body)
                         return
                     else:
                         try:
@@ -166,9 +346,30 @@ async def stream_endpoint(body: dict):
                                 "detail": "Selected content is too large for the model. Please reduce the number of files."
                             }
                             yield f"data: {json.dumps(error_msg)}\n\n"
+                            update_and_return(body)
                             await response.flush()
                             logger.info("Sent EventStreamError message: %s", error_msg)
                             return
+            except ChatGoogleGenerativeAIError as e:
+                if "token count" in str(e):
+                    error_msg = {
+                        "error": "validation_error",
+                        "detail": "Selected content is too large for the model. Please reduce the number of files."
+                    }
+                    yield f"data: {json.dumps(error_msg)}\n\n"
+                    update_and_return(body)
+                    await response.flush()
+                    logger.info("Sent token limit error message: %s", error_msg)
+                    return
+            except ResourceExhausted as e:
+                error_msg = {
+                    "error": "quota_exceeded",
+                    "detail": "API quota has been exceeded. Please try again in a few minutes."
+                }
+                yield f"data: {json.dumps(error_msg)}\n\n"
+                update_and_return(body)
+                logger.error(f"Caught ResourceExhausted error: {str(e)}")
+                return
             except EventStreamError as e:
                 if "validationException" in str(e):
                     error_msg = {
@@ -176,15 +377,18 @@ async def stream_endpoint(body: dict):
                         "detail": "Selected content is too large for the model. Please reduce the number of files."
                     }
                     yield f"data: {json.dumps(error_msg)}\n\n"
+                    update_and_return(body)
                     await response.flush()
                     return
                 raise
+            finally:
+                update_and_return(body)
         return StreamingResponse(error_handled_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})
     except Exception as e:
         logger.error(f"Error in stream endpoint: {str(e)}")
         error_msg = {"error": "stream_error", "detail": str(e)}
         logger.error(f"Sending error response: {error_msg}")
-
+        update_and_return(body)
         return StreamingResponse(iter([f"data: {json.dumps(error_msg)}\n\n"]), media_type="text/event-stream", headers={"Cache-Control": "no-cache"})


@@ -287,24 +491,145 @@ async def get_folders():
     return get_cached_folder_structure(user_codebase_dir, ignored_patterns, max_depth)

 @app.get('/api/default-included-folders')
-def
+def get_model_id():
     return {'defaultIncludedFolders': []}

+@app.get('/api/current-model')
+def get_current_model():
+    """Get detailed information about the currently active model."""
+    logger.info(
+        "Current model info request: %s",
+        { 'model_id': model.model_id,
+          'endpoint': os.environ.get("ZIYA_ENDPOINT", "bedrock")
+        })
+
+    # Get actual model settings
+    model_kwargs = {}
+    if hasattr(model, 'model') and hasattr(model.model, 'model_kwargs'):
+        model_kwargs = model.model.model_kwargs
+    elif hasattr(model, 'model_kwargs'):
+        model_kwargs = model.model_kwargs
+
+    logger.info("Current model configuration:")
+    logger.info(f"  Model ID: {model.model_id}")
+    logger.info(f"  Temperature: {model_kwargs.get('temperature', 'Not set')} (env: {os.environ.get('ZIYA_TEMPERATURE', 'Not set')})")
+    logger.info(f"  Top K: {model_kwargs.get('top_k', 'Not set')} (env: {os.environ.get('ZIYA_TOP_K', 'Not set')})")
+    logger.info(f"  Max tokens: {model_kwargs.get('max_tokens', 'Not set')} (env: {os.environ.get('ZIYA_MAX_OUTPUT_TOKENS', 'Not set')})")
+    logger.info(f"  Thinking mode: {os.environ.get('ZIYA_THINKING_MODE', 'Not set')}")
+
+
+    return {
+        'model_id': model.model_id,
+        'endpoint': os.environ.get("ZIYA_ENDPOINT", "bedrock"),
+        'settings': {
+            'temperature': model_kwargs.get('temperature',
+                float(os.environ.get("ZIYA_TEMPERATURE", 0.3))),
+            'max_output_tokens': model_kwargs.get('max_tokens',
+                int(os.environ.get("ZIYA_MAX_OUTPUT_TOKENS", 4096))),
+            'top_k': model_kwargs.get('top_k',
+                int(os.environ.get("ZIYA_TOP_K", 15))),
+            'thinking_mode': os.environ.get("ZIYA_THINKING_MODE") == "1"

+        }
+    }
+
 @app.get('/api/model-id')
 def get_model_id():
-
-
+    if os.environ.get("ZIYA_ENDPOINT") == "google":
+        model_name = os.environ.get("ZIYA_MODEL", "gemini-pro")
+        return {'model_id': model_name}
+    elif os.environ.get("ZIYA_MODEL"):
+        return {'model_id': os.environ.get("ZIYA_MODEL")}
+    else:
+        # Bedrock
+        return {'model_id': model.model_id.split(':')[0].split('/')[-1]}
+
+@app.post('/api/set-model')
+async def set_model(request: SetModelRequest):
+    """Set the active model for the current endpoint."""
+    try:
+        model_id = request.model_id
+        if not model_id:
+            logger.error("Empty model ID provided")
+            raise HTTPException(status_code=400, detail="Model ID is required")
+
+        # Update environment variable
+        os.environ["ZIYA_MODEL"] = model_id
+        logger.info(f"Setting model to: {model_id}")
+
+        # Reinitialize the model
+        try:
+            logger.info(f"Reinitializing model with ID: {model_id}")
+            new_model = ModelManager.initialize_model(force_reinit=True)
+            new_model.model_id = model_id  # Ensure model ID is set correctly
+
+            # Update the global model instance
+            global model
+            model = RetryingChatBedrock(new_model)
+
+            return {"status": "success", "model": model_id}
+        except Exception as e:
+            logger.error(f"Failed to initialize model: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Failed to initialize model: {str(e)}")
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get('/api/available-models')
+def get_available_models():
+    """Get list of available models for the current endpoint."""
+    endpoint = os.environ.get("ZIYA_ENDPOINT", "bedrock")
+    try:
+        models = []
+        for name, config in ModelManager.MODEL_CONFIGS[endpoint].items():
+            models.append({
+                "id": config["model_id"],
+                "name": name
+            })
+        return models
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+    return {'model_id': model.model_id.split(':')[0].split('/')[-1]}
+
+@app.get('/api/model-capabilities')
+def get_model_capabilities(model: str = None):
+    """Get the capabilities of the current model."""
+    endpoint = os.environ.get("ZIYA_ENDPOINT", "bedrock")
+    # If model parameter is provided, get capabilities for that model
+    # Otherwise use current model
+    model_name = model if model else os.environ.get("ZIYA_MODEL")
+
+    try:
+        model_config = ModelManager.get_model_config(endpoint, model_name)
+        capabilities = {
+            "supports_thinking": model_config.get("supports_thinking", False),
+            "max_output_tokens": model_config.get("max_output_tokens", 4096),
+            "temperature_range": {"min": 0, "max": 1, "default": model_config.get("temperature", 0.3)},
+            "top_k_range": {"min": 0, "max": 500, "default": model_config.get("top_k", 15)} if endpoint == "bedrock" else None
+        }
+        return capabilities
+    except Exception as e:
+        logger.error(f"Error getting model capabilities: {str(e)}")
+        return {"error": str(e)}

 class ApplyChangesRequest(BaseModel):
     diff: str
     filePath: str

+class ModelSettingsRequest(BaseModel):
+    temperature: float = Field(default=0.3, ge=0, le=1)
+    top_k: int = Field(default=15, ge=0, le=500)
+    max_output_tokens: int = Field(default=4096, ge=1, le=128000)
+    thinking_mode: bool = Field(default=False)
+
+
 class TokenCountRequest(BaseModel):
     text: str

 def count_tokens_fallback(text: str) -> int:
     """Fallback methods for counting tokens when primary method fails."""
     try:
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
         # First try using tiktoken directly with cl100k_base (used by Claude)
         encoding = tiktoken.get_encoding("cl100k_base")
         return len(encoding.encode(text))
@@ -346,6 +671,70 @@ async def count_tokens(request: TokenCountRequest) -> Dict[str, int]:
         # Return 0 in case of error to avoid breaking the frontend
         return {"token_count": 0}

+@app.post('/api/model-settings')
+async def update_model_settings(settings: ModelSettingsRequest):
+    global model
+    try:
+        # Log the requested settings
+        logger.info(f"Requested model settings update:")
+        logger.info(f"  Temperature: {settings.temperature}")
+        logger.info(f"  Top K: {settings.top_k}")
+        logger.info(f"  Max Output Tokens: {settings.max_output_tokens}")
+        logger.info(f"  Thinking Mode: {settings.thinking_mode}")
+
+        # Store settings in environment variables for the agent to use
+        os.environ["ZIYA_TEMPERATURE"] = str(settings.temperature)
+        os.environ["ZIYA_TOP_K"] = str(settings.top_k)
+        os.environ["ZIYA_MAX_OUTPUT_TOKENS"] = str(settings.max_output_tokens)
+        os.environ["ZIYA_THINKING_MODE"] = "1" if settings.thinking_mode else "0"
+
+        # Update the model's kwargs directly
+        if hasattr(model, 'model'):
+            # For wrapped models (e.g., RetryingChatBedrock)
+            if hasattr(model.model, 'model_kwargs'):
+                model.model.model_kwargs.update({
+                    'temperature': settings.temperature,
+                    'top_k': settings.top_k,
+                    'max_tokens': settings.max_output_tokens
+                })
+        elif hasattr(model, 'model_kwargs'):
+            # For direct model instances
+            model.model_kwargs.update({
+                'temperature': settings.temperature,
+                'top_k': settings.top_k,
+                'max_tokens': settings.max_output_tokens
+            })
+
+        # Force model reinitialization to apply new settings
+        from app.agents.models import ModelManager
+        model = ModelManager.initialize_model(force_reinit=True)
+        model.model_id = os.environ.get("ZIYA_MODEL", model.model_id)
+
+        # Get the model's current settings for verification
+        model_kwargs = {}
+        if hasattr(model, 'model') and hasattr(model.model, 'model_kwargs'):
+            model_kwargs = model.model.model_kwargs
+        elif hasattr(model, 'model_kwargs'):
+            model_kwargs = model.model_kwargs
+
+        logger.info("Current model settings after update:")
+        logger.info(f"  Model kwargs temperature: {model_kwargs.get('temperature', 'Not set')}")
+        logger.info(f"  Model kwargs top_k: {model_kwargs.get('top_k', 'Not set')}")
+        logger.info(f"  Model kwargs max_tokens: {model_kwargs.get('max_tokens', 'Not set')}")
+        logger.info(f"  Environment ZIYA_THINKING_MODE: {os.environ.get('ZIYA_THINKING_MODE')}")
+
+        return {
+            'status': 'success',
+            'message': 'Model settings updated',
+            'settings': model_kwargs
+        }
+    except Exception as e:
+        logger.error(f"Error updating model settings: {str(e)}", exc_info=True)
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error updating model settings: {str(e)}"
+        )
+
 @app.post('/api/apply-changes')
 async def apply_changes(request: ApplyChangesRequest):
     try:
@@ -405,4 +794,4 @@ async def apply_changes(request: ApplyChangesRequest):
     )

 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=
+    uvicorn.run(app, host="0.0.0.0", port=DEFAULT_PORT)